{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用BP神经网络学习算法建立Sepal.Length、Sepal.Width、Petal.Length对Petal.Width回归问题的神经网络模型，首先进行数据准备和初始化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Seed NumPy's global RNG so the random initial weights below (and therefore\n",
    "# the whole training run) are the same on every Restart & Run All.\n",
    "SEED = 1\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Load the iris data: every column except Species and Petal.Width is a\n",
    "# feature, and Petal.Width is the regression target.\n",
    "iris = pd.read_csv(\"http://image.cador.cn/data/iris.csv\")\n",
    "x,y = iris.drop(columns=['Species','Petal.Width']),iris['Petal.Width']\n",
    "\n",
    "# Standardise each feature column to zero mean and unit (population, ddof=0)\n",
    "# standard deviation.\n",
    "# NOTE(review): the statistics are computed on the full data set before the\n",
    "# train/test split, so the test rows influence the scaling - confirm this is\n",
    "# acceptable for the intended evaluation.\n",
    "x = (x - x.mean()) / x.std(ddof=0)\n",
    "\n",
    "# Prepend a constant column of ones so the first weight of every hidden\n",
    "# neuron acts as its bias term.\n",
    "x = np.c_[np.ones(x.shape[0]), x]\n",
    "x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33, random_state=1)\n",
    "\n",
    "# Learning rate\n",
    "alpha = 0.01\n",
    "\n",
    "# Network dimensions, derived from the data instead of being hard-coded.\n",
    "n_input = x_train.shape[1]   # bias column + feature columns\n",
    "n_output = 1                 # single regression target\n",
    "\n",
    "# Heuristic estimate of the number of hidden neurons from the input and\n",
    "# output layer sizes (same coefficients as before; with 4 inputs and 1\n",
    "# output it evaluates to 4).\n",
    "m = int(np.round(np.sqrt(0.43*n_output*n_input + 0.12*n_output**2 + 2.54*n_input + 0.77*n_output + 0.35) + 0.51, 0))\n",
    "\n",
    "# Input-to-hidden weights: one row per hidden neuron, one column per input.\n",
    "wInput = np.random.uniform(-1, 1, (m, n_input))\n",
    "\n",
    "# Hidden-to-output weights: one weight per hidden neuron.\n",
    "wHide = np.random.uniform(-1, 1, m)\n",
    "\n",
    "# Threshold shared by the per-sample update test and the stopping rule.\n",
    "epsilon = 1e-3\n",
    "\n",
    "# Per-epoch residual sums of squares, appended to by the training loop.\n",
    "errorList = []"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "权重矩阵wInput的行表示隐含层每个神经元对应输入的权重向量，而列表示输入层的每个神经元对应隐含层的权重向量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter: 0 error: 141.83755252818656\n",
      "iter: 1 error: 45.64650176076224\n",
      "iter: 2 error: 14.778942978562807\n",
      "iter: 3 error: 9.614196163924671\n",
      "iter: 4 error: 8.75217771358735\n",
      "iter: 5 error: 8.349814267796166\n",
      "iter: 6 error: 8.040623878098147\n",
      "iter: 7 error: 7.772800204247519\n",
      "iter: 8 error: 7.532654528607033\n",
      "iter: 9 error: 7.314046322298962\n",
      "iter: 10 error: 7.113029312738285\n",
      "iter: 11 error: 6.926582553460887\n",
      "iter: 12 error: 6.752289274397546\n",
      "iter: 13 error: 6.588165293935329\n",
      "iter: 14 error: 6.432578592211868\n",
      "iter: 15 error: 6.284184606823426\n",
      "iter: 16 error: 6.141871367400217\n",
      "iter: 17 error: 6.00474878849802\n",
      "iter: 18 error: 5.87206759675209\n",
      "iter: 19 error: 5.743233797874961\n",
      "iter: 20 error: 5.617785486093765\n",
      "iter: 21 error: 5.4953701511178785\n",
      "iter: 22 error: 5.375723787654714\n",
      "iter: 23 error: 5.25871632964895\n",
      "iter: 24 error: 5.144184318051799\n",
      "iter: 25 error: 5.032195765784499\n",
      "iter: 26 error: 4.922761880559512\n",
      "iter: 27 error: 4.815992769876716\n",
      "iter: 28 error: 4.712043016204116\n",
      "iter: 29 error: 4.611094440739704\n",
      "iter: 30 error: 4.5133276262399\n",
      "iter: 31 error: 4.419126312127618\n",
      "iter: 32 error: 4.328543413410128\n",
      "iter: 33 error: 4.241869178380779\n",
      "iter: 34 error: 4.159339781542005\n",
      "iter: 35 error: 4.081107037337377\n",
      "iter: 36 error: 4.007320581755775\n",
      "iter: 37 error: 3.938036319828986\n",
      "iter: 38 error: 3.8733655815638195\n",
      "iter: 39 error: 3.813241241122353\n",
      "iter: 40 error: 3.757556693270715\n",
      "iter: 41 error: 3.706296098080222\n",
      "iter: 42 error: 3.6591762848465743\n",
      "iter: 43 error: 3.616073881838798\n",
      "iter: 44 error: 3.5767281409377314\n",
      "iter: 45 error: 3.5408776462863174\n",
      "iter: 46 error: 3.508283260149281\n",
      "iter: 47 error: 3.478616051776632\n",
      "iter: 48 error: 3.4516722431614526\n",
      "iter: 49 error: 3.427139308272452\n",
      "iter: 50 error: 3.404810728882119\n",
      "iter: 51 error: 3.384410467210233\n",
      "iter: 52 error: 3.365703021950399\n",
      "iter: 53 error: 3.3486157670477206\n",
      "iter: 54 error: 3.3328338683978354\n",
      "iter: 55 error: 3.3182700564408445\n",
      "iter: 56 error: 3.30476237708151\n",
      "iter: 57 error: 3.2921987816935308\n",
      "iter: 58 error: 3.280467792163049\n",
      "iter: 59 error: 3.2694974680353615\n",
      "iter: 60 error: 3.259153189988303\n",
      "iter: 61 error: 3.249410977741892\n",
      "iter: 62 error: 3.2402011784284652\n",
      "iter: 63 error: 3.2314963407237633\n",
      "iter: 64 error: 3.2232070764129452\n",
      "iter: 65 error: 3.2153177179057497\n",
      "iter: 66 error: 3.207811642304872\n",
      "iter: 67 error: 3.2006092423204198\n",
      "iter: 68 error: 3.1937168492583097\n",
      "iter: 69 error: 3.1871096494438573\n",
      "iter: 70 error: 3.180769385127915\n",
      "iter: 71 error: 3.174672175062576\n",
      "iter: 72 error: 3.1688499063538482\n",
      "iter: 73 error: 3.163196026792294\n",
      "iter: 74 error: 3.157770720974648\n",
      "iter: 75 error: 3.1525258554136455\n",
      "iter: 76 error: 3.147463297968372\n",
      "iter: 77 error: 3.1425717099343577\n",
      "iter: 78 error: 3.137841675396403\n",
      "iter: 79 error: 3.133264599604336\n",
      "iter: 80 error: 3.1288325119802263\n",
      "iter: 81 error: 3.124537987656333\n",
      "iter: 82 error: 3.1203740920019034\n",
      "iter: 83 error: 3.1163343348806594\n",
      "iter: 84 error: 3.1124126316939424\n",
      "iter: 85 error: 3.108603269808055\n",
      "iter: 86 error: 3.1049008793611965\n",
      "iter: 87 error: 3.101300407656968\n",
      "iter: 88 error: 3.097797096508078\n",
      "iter: 89 error: 3.0943864620193646\n",
      "iter: 90 error: 3.0910642764005596\n",
      "iter: 91 error: 3.0878480605507224\n",
      "iter: 92 error: 3.0846797354289754\n",
      "iter: 93 error: 3.081597391989877\n",
      "iter: 94 error: 3.0785894570906938\n",
      "iter: 95 error: 3.075652133243253\n",
      "iter: 96 error: 3.0728052025916\n",
      "iter: 97 error: 3.069985260218296\n",
      "iter: 98 error: 3.0672380155451133\n",
      "iter: 99 error: 3.064551333528501\n",
      "iter: 100 error: 3.0619219160784636\n",
      "iter: 101 error: 3.0593474482532708\n",
      "iter: 102 error: 3.0568257994952797\n",
      "iter: 103 error: 3.0543549522840605\n",
      "iter: 104 error: 3.0519329945922395\n",
      "iter: 105 error: 3.0495920676080597\n",
      "iter: 106 error: 3.0472554915194032\n",
      "iter: 107 error: 3.0449988389682354\n",
      "iter: 108 error: 3.0427512642952874\n",
      "iter: 109 error: 3.0405528867764624\n",
      "iter: 110 error: 3.0383944714828446\n",
      "iter: 111 error: 3.0362740578540826\n",
      "iter: 112 error: 3.0341903860430186\n",
      "iter: 113 error: 3.032142285228146\n",
      "iter: 114 error: 3.030128634333067\n",
      "iter: 115 error: 3.0281483632515314\n",
      "iter: 116 error: 3.0262004525525628\n",
      "iter: 117 error: 3.024283931462303\n",
      "iter: 118 error: 3.0223978753108693\n",
      "iter: 119 error: 3.0205414029175066\n",
      "iter: 120 error: 3.0187136740641622\n",
      "iter: 121 error: 3.0169138870997747\n",
      "iter: 122 error: 3.0151412766835963\n",
      "iter: 123 error: 3.0133951116644853\n",
      "iter: 124 error: 3.011674693090436\n",
      "iter: 125 error: 3.009979352341009\n",
      "iter: 126 error: 3.0083084493757712\n",
      "iter: 127 error: 3.006661371091807\n",
      "iter: 128 error: 3.0050375297837615\n",
      "iter: 129 error: 3.0034363617001687\n",
      "iter: 130 error: 3.001857325690202\n",
      "iter: 131 error: 3.0002999019352448\n",
      "iter: 132 error: 2.9987635907600443\n",
      "iter: 133 error: 2.997247911518512\n",
      "iter: 134 error: 2.9957524015494212\n",
      "iter: 135 error: 2.9942766151977\n",
      "iter: 136 error: 2.9928201228971014\n",
      "iter: 137 error: 2.991382510310414\n",
      "iter: 138 error: 2.9899633775234964\n",
      "iter: 139 error: 2.988562338289721\n",
      "iter: 140 error: 2.9871790193216587\n",
      "iter: 141 error: 2.9858130596268215\n",
      "iter: 142 error: 2.9844641098848017\n",
      "iter: 143 error: 2.983131831863043\n",
      "iter: 144 error: 2.9818158978686755\n",
      "iter: 145 error: 2.980515990234295\n",
      "iter: 146 error: 2.979231800835244\n",
      "iter: 147 error: 2.977963030636533\n",
      "iter: 148 error: 2.976694234863014\n",
      "iter: 149 error: 2.975456809107571\n",
      "iter: 150 error: 2.97423734092027\n",
      "iter: 151 error: 2.973032995007187\n",
      "iter: 152 error: 2.9718431099531486\n",
      "iter: 153 error: 2.970667326517445\n",
      "iter: 154 error: 2.9695053462091376\n",
      "iter: 155 error: 2.9683568922884658\n",
      "iter: 156 error: 2.967197464917074\n",
      "iter: 157 error: 2.9660651763638874\n",
      "iter: 158 error: 2.9649479593907646\n",
      "iter: 159 error: 2.963843477099628\n",
      "iter: 160 error: 2.962751255857821\n",
      "iter: 161 error: 2.9616710485052193\n",
      "iter: 162 error: 2.9606026438219915\n",
      "iter: 163 error: 2.959545841123406\n",
      "iter: 164 error: 2.9585004457696455\n",
      "iter: 165 error: 2.957466267984824\n",
      "iter: 166 error: 2.956443122398256\n",
      "iter: 167 error: 2.9554308277951384\n",
      "iter: 168 error: 2.954429206940582\n",
      "iter: 169 error: 2.9534380864337098\n"
     ]
    }
   ],
   "source": [
    "# Train with online (per-sample) gradient descent for at most 1000 epochs.\n",
    "for p in range(1000):\n",
    "    # Residual sum of squares accumulated over this epoch\n",
    "    error = 0\n",
    "    for i in range(x_train.shape[0]):\n",
    "        # Forward pass\n",
    "        xInput = x_train[i,:]\n",
    "        d = np.matmul(wInput,xInput)\n",
    "        # np.tanh is the same activation as (e^d - e^-d)/(e^d + e^-d) but\n",
    "        # does not overflow to inf/inf = NaN when |d| is large.\n",
    "        z = np.tanh(d)\n",
    "        o = np.matmul(wHide,z)\n",
    "        e = o - y_train.values[i]\n",
    "        error = error + e**2\n",
    "\n",
    "        # Backward pass, skipped when |e| <= epsilon\n",
    "        if np.abs(e) > epsilon:\n",
    "            # The hidden-layer step must use the wHide that produced this\n",
    "            # prediction, so compute it before wHide is updated.\n",
    "            # tanh'(d) = 1 - tanh(d)^2, which cannot overflow.\n",
    "            a = (1 - z**2)*wHide*alpha*e\n",
    "            wHide = wHide - alpha*z*e\n",
    "            # Row j of the update is a[j] * xInput.\n",
    "            wInput = wInput - np.outer(a, xInput)\n",
    "\n",
    "    errorList.append(error)\n",
    "    print(\"iter:\",p,\"error:\",error)\n",
    "    # Stop once the epoch-to-epoch decrease in the residual sum of squares\n",
    "    # falls below epsilon (this also stops if the error went up). Two\n",
    "    # recorded epochs are enough to make the comparison.\n",
    "    if len(errorList) >= 2 and errorList[-2] - errorList[-1] < epsilon:\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "查看最终的残差平方和以及输入层到隐含层和隐含层到输出层的权值向量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.9534380864337098"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Training residual sum of squares from the last completed epoch\n",
    "error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-0.28168746, -0.7399645 , -0.31006486, -0.92081159],\n",
       "       [-0.20040921, -0.17269222, -0.20130004,  1.15218551],\n",
       "       [ 0.86645462,  0.20299966,  0.88272144, -0.8343931 ],\n",
       "       [ 1.11096542, -0.25095136,  0.1185884 ,  0.21807303]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Learned input-to-hidden weights: one row per hidden neuron, one column per\n",
    "# input (the constant bias column comes first)\n",
    "wInput"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.07675219,  1.20997666,  0.33624993,  1.33659707])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Learned hidden-to-output weights: one weight per hidden neuron\n",
    "wHide"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现基于学习出来的权重参数对x_test数据集进行预测，并计算残差平方和"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.224847843625698"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Forward pass for every test row using the learned weights.\n",
    "# np.tanh replaces the hand-written (e^d - e^-d)/(e^d + e^-d), which overflows\n",
    "# to NaN for large |d|. The comprehension keeps y_pred a list and no longer\n",
    "# overwrites the training loop's xInput/d/z/o variables.\n",
    "y_pred = [\n",
    "    np.matmul(wHide, np.tanh(np.matmul(wInput, row)))\n",
    "    for row in x_test\n",
    "]\n",
    "\n",
    "# Residual sum of squares on the test set\n",
    "np.sum((y_test.values - y_pred)**2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
